rm(list = ls())
gc()
require(tidyverse)
require(fixest)
require(lubridate)
require(caret)
require(ranger)
require(fastDummies)

# Every relative path used below (./final_01_2024.RData, ./gallup/...,
# ./results/...) is resolved against this directory.
setwd("/scratch/jhb362/zilinsky_2023/data/")

# Helper function for temporal transformation.
#
# Arguments:
#   x    - numeric vector ordered in time (oldest first). Lags are positional,
#          so the caller is responsible for ordering and grouping.
#   type - one of:
#            "lag"    value n positions earlier
#            "chg"    difference from the value n positions earlier
#            "pctChg" proportional change from the value n positions earlier
#            "mav"    unweighted mean of the current value and the n
#                     preceding values (n + 1 terms)
#            "mavWgt" fixed-weight average of the current value and the four
#                     preceding values (weights .5/.2/.15/.1/.05); ignores n
#   n    - lag length / window size, in positions.
#
# Returns a vector the same length as x, NA wherever the required history is
# missing. Stops with an error for an unrecognised type.
#
# NOTE: relies on a lag(x, n = ) with dplyr semantics being on the search path.
temp.fun <- function(x,type = "lag",n = 1) {
  if(type == "lag") {
    return(lag(x,n = n))
  }
  if(type == "chg") {
    return(x - lag(x,n = n))
  }
  if(type == "pctChg") {
    return((x - lag(x,n = n))/lag(x,n = n))
  }
  if(type == "mav") {
    # Previously the accumulator started as NULL, so NULL + x collapsed to
    # numeric(0) and the branch always returned the scalar 0. Start from the
    # current value, add each of the n lags elementwise, and divide by the
    # number of terms actually summed.
    x2 <- x
    for(i in seq_len(n)) {
      x2 <- x2 + lag(x,n = i)
    }
    return(x2/(n + 1))
  }
  if(type == "mavWgt") {
    return(x*.5 + lag(x)*.2 + lag(x,n=2)*.15 + lag(x,n=3)*.1 + lag(x,n=4)*.05)
  }
  # Falling through used to return NULL, which makes mutate() silently drop
  # the column; fail loudly instead.
  stop("temp.fun: unknown type '",type,
       "' (expected lag, chg, pctChg, mav or mavWgt)",call. = FALSE)
}

# Positional command-line arguments:
#   1. bsInd - bootstrap replicate index; also used as the RNG seed
#   2. per   - time bucket for the survey data ('days', 'months', 'quarters',
#              'years') or 'all' to pool everything
#   3. dummy - 'TRUE'/'FALSE'; whether factor predictors are dummy-coded
#   4. temp  - temporal transformation and horizon, e.g. 'chg12'
args <- commandArgs(trailingOnly = TRUE)
# args <- c(1,'days','FALSE','chg12') #c(25,1,'months','FALSE','chg12') # months, quarters, years, all

# Fail up front with a readable message instead of a cryptic set.seed(NA) or
# if(NA) error later on.
if(length(args) < 4) {
  stop('Expected 4 arguments: <bsInd> <period> <dummy TRUE/FALSE> <temp, e.g. chg12>',
       call. = FALSE)
}

bsInd <- suppressWarnings(as.numeric(args[1]))
per <- as.character(args[2])
dummy <- as.logical(args[3])
temp <- as.character(args[4])

if(is.na(bsInd)) {
  stop("bsInd (argument 1) must be numeric, got '",args[1],"'",call. = FALSE)
}
if(is.na(dummy)) {
  stop("dummy (argument 3) must be TRUE or FALSE, got '",args[3],"'",call. = FALSE)
}
set.seed(bsInd)

# Thread count: use the SLURM allocation when it is set to a usable integer,
# otherwise fall back to 8.
slurmTasks <- Sys.getenv("SLURM_NTASKS_PER_NODE")
cat(slurmTasks,'\n')
ncores <- suppressWarnings(as.integer(slurmTasks))
if(is.na(ncores) || ncores < 1) {
  ncores <- 8 #parallel::detectCores()
}

# Most important predictors for economic evaluations
# The workspace loaded here is expected to supply `gal` and `ctyDat`, which
# the rest of the script reads.
load('./final_01_2024.RData')

class(gal$INT_DATE)

# Tag every county column that is not an id/date field and not already in the
# DEM_/HLTH_/CRIME_ families as ECON_ (stripping the LAU_ marker), then discard
# the columns that are not used as predictors.
ctyDat <- ctyDat %>%
  rename_with(.fn = function(nm) paste0('ECON_',gsub('LAU_','',nm)),
              .cols = -matches('stcou|year|date|^DEM|^HLTH|^CRIME')) %>%
  select(-matches('_NA_|all_ind|_wages$'),
         -ECON_emp,
         -ECON_lf,
         -ECON_unemp)

# Prep on the county-level data
# Dropping crime data which is crazy skewed
require(moments)

# Skewness of each CRIME_ARRST point-estimate column (the _lb/_ub bound
# columns are excluded from the calculation).
skews <- vapply(ctyDat %>% select(matches('CRIME_ARRST')) %>% select(-matches('_lb|_ub')),
                function(x) skewness(x,na.rm = TRUE),
                numeric(1))

# Everything that is not demonstrably below the threshold is dropped (columns
# whose skewness is NA included). A logical mask is used rather than
# skews[-which(...)]: with no column below the threshold, -which() yields an
# empty index and would select nothing to drop.
lowSkew <- !is.na(skews) & skews < 6
drops <- skews[!lowSkew]

# With nothing to drop, the collapsed pattern would be '' and matches('')
# selects every column, wiping the whole table - so only filter when needed.
# Stripping '_pct' makes the pattern match the base name and its variants.
if(length(drops) > 0) {
  ctyDat <- ctyDat %>% select(-matches(paste(gsub('_pct','',names(drops)),collapse = '|')))
}


# Respondent-level predictors from the survey data.
indivCovs <- c('AGECAT','PARTY','RACE','GENDER','EDUC','MARST','ANNUALINC')

# County-level predictor groups, picked out by column-name pattern. Columns
# carrying lower/upper-bound markers are left out of every group.
econCovs <- ctyDat %>%
  select(matches('^ECON_(ur|lfpr|tot_aww|.*lq_emplvl)')) %>%
  colnames()
healthCovs <- ctyDat %>%
  select(matches('^HLTH_(Male|Female)_Total')) %>%
  select(-matches('_lb$|_ub$')) %>%
  colnames()
crimeCovs <- ctyDat %>%
  select(matches('^CRIME_')) %>%
  select(-matches('_lb_|_ub_')) %>%
  colnames()
demogCovs <- ctyDat %>%
  select(matches('^DEM_')) %>%
  select(-matches('_lb$|_ub$')) %>%
  colnames()

# ZIP-level labour-force participation: labour force over total population.
gal$ZIPDEM_LFPR <- with(gal,ZIPDEM_LF / ZIPDEM_totpop)

# ZIP-prefixed columns, minus code columns and the raw labour-force count.
zipCovs <- gal %>%
  select(matches('^ZIP')) %>%
  select(-matches('CODE|_LF$')) %>%
  colnames()

# Population enters on the log scale.
ctyDat <- ctyDat %>%
  mutate(DEM_pop_tot = log(DEM_pop_tot))


# Columns to log-transform: for every retained (skewness < 6) crime series,
# the point estimate plus its lower/upper bound variants.
# paste0(names,suffixes) recycles element-wise, so it only ever paired each
# base name with ONE of the three suffixes; outer() builds the full
# name-by-suffix grid. Names that do not exist in ctyDat are discarded so a
# series without bound columns does not abort the run.
lgBase <- gsub('_pct','',names(skews[which(skews < 6)]))
lgs <- as.vector(outer(lgBase,c('_pct','_lb_pct','_ub_pct'),paste0))
lgs <- intersect(lgs,colnames(ctyDat))
scls <- c(econCovs,healthCovs,crimeCovs,demogCovs)
fcts <- c(indivCovs)

# log(x + 1) keeps zero counts/rates finite.
ctyDat <- ctyDat %>%
  mutate(across(all_of(lgs),function(x) log(x+1)))

# Temporal transformation: some county series are monthly and the rest are
#   annual, so the transformation is applied only to the monthly series. In
#   effect, demographics and crime are treated as influencing opinions through
#   their LEVELS, while health and economics are treated as influencing
#   opinions through their CHANGES. (This distinction only matters when a
#   temporal transformation is requested via the command line.)
# test <- ctyDat %>%
#   select(stcou,year,all_of(c(econCovs,healthCovs,crimeCovs,demogCovs))) %>%
#   group_by(stcou,year) %>%
#   summarise_all(var,na.rm=T)
# 
# test2<-sapply(test,function(x) all(x == 0))
# test2[-which(test2)]
# ctyDat %>%
#   select(stcou,date,year,ECON_ur) %>%
#   distinct() %>%
#   group_by(stcou) %>%
#   arrange(date) %>%
#   mutate(ECON_ur = temp.fun(ECON_ur,type = args[5])) %>%
#   ungroup() %>%
#   filter(stcou != '') %>%
#   arrange(stcou,date)

# Apply the requested temporal transformation to the ECON_/HLTH_ series
# (excluding the *_aww columns), separately within each county and in date
# order.
# `temp` encodes the transformation name followed by an optional horizon,
# e.g. 'chg12'. It is parsed once here instead of once per column.
tempCols <- colnames(ctyDat %>% select(matches('ECON_|HLTH_')) %>%
                       select(-matches('_aww')))
tempType <- gsub('\\d+','',temp)
tempN <- as.numeric(str_extract(temp,'\\d+'))
# Without digits (e.g. 'lag' or 'mavWgt') the horizon parsed to NA, which
# breaks lag(); fall back to temp.fun's own default of 1.
if(is.na(tempN)) {
  tempN <- 1
}

ctyDat <- ctyDat %>%
  group_by(stcou) %>%
  arrange(date) %>%
  mutate(across(all_of(tempCols),
                function(x) temp.fun(x,type = tempType,n = tempN))) %>% # NB: temp.fun should be transformed IN TERMS OF MONTHS
  ungroup()

# Attach the county covariates to the survey records. No `by` is given, so the
# join uses every column name the two tables share.
dat <- left_join(
  rename(gal,year = YEAR,wgt = COMB_STATEWT),
  select(ctyDat,year,stcou,date,DEM_pop_tot,
         all_of(c(econCovs,healthCovs,crimeCovs,demogCovs)))
)

unique(gal$date)

# `time` is the bucket each model is fit within: a single pooled bucket for
# 'all', otherwise the interview date rounded to the requested unit (the
# trailing 's' of e.g. 'months' is stripped to form the lubridate unit name).
if(per == 'all') {
  dat$time <- 'all'
} else {
  perUnit <- gsub('s$','',per)
  dat <- dat %>%
    mutate(time = as.Date(lubridate::round_date(INT_DATE,unit = perUnit)))
}

count(dat,time)

# Permutation test + OLS
# Brings in `lookup`, the value-label table keyed by the raw survey item codes.
load('./gallup/var_lookup.RData')

# Analysis name -> raw survey item code.
lookup2 <- c(
  ID = 'MOTHERLODE_ID',
  UNION = 'D17A',
  EDUC = 'EDUCATION',
  RELIG = 'D8B',
  ANNUALINC = 'INCOME_SUMMARY',
  MONTHINC = 'MONTHLY_INCOME',
  RACE = 'RACE',
  SELFEMP = 'WP10202',
  AGE = 'WP1220',
  MARST = 'WP1223',
  OCC = 'WP1225',
  GENDER = 'SC7',
  IDEO = 'P20',
  PARTY = 'PARTY',
  HLTHINS = 'H14',
  SOLSATCOMP = 'HWB17',
  COMMHAP = 'HWB22',
  MONWORRY = 'HWB6',
  ECON = 'M30',
  WATCHSPEND = 'M91',
  MAJORPURCH = 'M92',
  CUTBACKSPEND = 'M93',
  FEELGOODMON = 'M94',
  MONWORRYYEST = 'M95',
  MONMORETHANENOUGH = 'M96',
  ENOUGHMON = 'M97',
  FEELBETTERFIN = 'M97A',
  TRUMPAPP = 'P1167',
  TRUMPAPPLN = 'P1167F',
  OBAMAPP = 'P128',
  HILFAV = 'P919A',
  VAGOVFAV = 'P919W',
  NATECONIMP = 'WP148',
  LIFETOD = 'WP16',
  LIFE5YR = 'WP18',
  SOLSAT = 'WP30',
  HAPPY = 'WP6878',
  WORRY = 'WP69',
  SAD = 'WP70',
  STRESS = 'WP71',
  SWB = 'WELL_BEING_INDEX'
)

# Same mapping as a two-column table: varsNEW (analysis name), vars (raw code).
lookup2 <- tibble(varsNEW = names(lookup2),
                  vars = unname(lookup2))

# Keep only the label rows whose raw code appears in the mapping, and re-key
# them by the analysis name (which ends up as the last column, named `vars`).
lookup <- lookup %>%
  inner_join(lookup2) %>%
  select(-vars) %>%
  rename(vars = varsNEW)

# Finalise the value-label table used further down to turn coded survey
# columns into labelled factors.
lookup <- lookup %>%
  # Override the labels for codes 2/3/4/9 of the three agreement-scale items.
  # NOTE(review): these three items are removed again by the filter() below,
  # so the overrides do not survive into the final table.
  mutate(val_lab = ifelse(vars %in% c('COMMHAP','MONWORRY','SOLSATCOMP') & vals == 2,'Somewhat disagree',
                          ifelse(vars %in% c('COMMHAP','MONWORRY','SOLSATCOMP') & vals == 3,'Neutral',
                                 ifelse(vars %in% c('COMMHAP','MONWORRY','SOLSATCOMP') & vals == 4,'Somewhat agree',
                                        ifelse(vars %in% c('COMMHAP','MONWORRY','SOLSATCOMP') & vals == 9,'REF',val_lab))))) %>%
  # AGECAT has no entry in the loaded table; add its six categories by hand.
  bind_rows(data.frame(vars = rep('AGECAT',6),
                       var_lab = 'Age category',
                       vals = 1:6,
                       val_lab = c('LT25','25-34','35-44','45-54','55-64','65Up'))) %>%
  # Drop the label rows for these variables so that the factor-conversion
  # loop, which only touches variables present in this table, skips them.
  filter(!vars %in% c('TRUMPAPP','ENOUGHMON','HILFAV','SOLSAT','WORRY','SAD','STRESS','HAPPY','SELFEMP','OBAMAPP','VAGOVFAV',
                      'ECON','NATECONIMP','LIFETOD','LIFE5YR','AGE','FEELBETTERFIN','WATCHSPEND','COMMHAP','SOLSATCOMP','MONWORRY',
                      'MONMORETHANENOUGH','MONWORRYYEST','FEELGOODMON','CUTBACKSPEND','MAJORPURCH','TRUMPAPPLN')) %>%
  # Normalise labels, innermost call first: strip parentheses, commas, '$',
  # '.', square brackets and apostrophes; turn spaces, hyphens/dashes and
  # slashes into '_'; collapse runs of '_'; truncate to 30 characters.
  # NOTE(review): the middle pattern contains non-ASCII dash characters and its
  # last alternative looks like a second space (possibly a non-breaking
  # space) - copy this literal rather than retyping it.
  mutate(val_lab = substr(gsub('_{2,}','_',gsub(' |-|–|\\/|—| ','_',gsub("\\(|\\)|,|\\$|\\.|\\[|\\]|'",'',val_lab))),1,30))

# Reduce complexity of EDUC & INC
# Their label rows are removed here; both variables are recoded by hand into
# coarser categories instead of being labelled from this table.
lookup <- lookup %>%
  filter(vars != 'EDUC',
         vars != 'ANNUALINC')


# Collapse education and income into a handful of ordered categories.
# EDUC codes outside 1-8 (and missing values) become NA; ANNUALINC codes
# outside 1-10 (and missing values) fall into 'DK_REF'.
dat <- dat %>%
  mutate(EDUC = factor(ifelse(EDUC == 1,'LTHS',
                              ifelse(EDUC == 2,'HSDeg',
                                     ifelse(EDUC %in% c(3:5),'SoCo',
                                            ifelse(EDUC %in% 6:8,'CollDegUp',NA)))),
                       levels = c('LTHS','HSDeg','SoCo','CollDegUp')),
         ANNUALINC = factor(ifelse(ANNUALINC %in% 1:4,'LT_23999',
                                   ifelse(ANNUALINC %in% 5:6,'24000_47999',
                                          ifelse(ANNUALINC %in% 7:8,'48000_89999',
                                                 ifelse(ANNUALINC %in% 9:10,'90000_up','DK_REF')))),
                            levels = c('LT_23999','24000_47999','48000_89999','90000_up','DK_REF')))

# Convert every coded column that has entries in `lookup` into a labelled
# factor.
for(i in intersect(colnames(dat),unique(lookup$vars))) {
  tmp <- lookup %>%
    filter(vars == i,
           vals %in% unique(dat[[i]])) %>%
    arrange(vals)

  # Passing only `labels` pairs the lookup rows with the sorted data values
  # purely by position, so any mis-ordered lookup silently mislabels the
  # variable. Map code -> label explicitly via `levels` instead, and keep
  # failing loudly when the data holds a code the lookup does not know.
  obsVals <- unique(dat[[i]][!is.na(dat[[i]])])
  unmatched <- setdiff(as.character(obsVals),as.character(tmp$vals))
  if(length(unmatched) > 0) {
    stop("No value label in lookup for ",i," code(s): ",
         paste(unmatched,collapse = ', '),call. = FALSE)
  }
  dat[[i]] <- factor(dat[[i]],levels = tmp$vals,labels = tmp$val_lab)
}

# Make the no-lean independents the reference category for PARTY.
dat <- dat %>%
  mutate(PARTY = relevel(PARTY,ref = 'Independent_no_lean'))

# Preparing ranger
# Predictors: respondent-level factors first, then the county-level groups.
Xs <- c(indivCovs,econCovs,healthCovs,crimeCovs,demogCovs)

# Outcomes; a separate set of models is fit for each one.
Ys <- c(
  'TRUMPAPP','ENOUGHMON','HILFAV','SOLSAT','WORRY','SAD','STRESS','HAPPY',
  'OBAMAPP','VAGOVFAV',
  'ECON','NATECONIMP','LIFETOD','LIFE5YR','FEELBETTERFIN','WATCHSPEND',
  'COMMHAP','SOLSATCOMP','MONWORRY',
  'MONMORETHANENOUGH','MONWORRYYEST','FEELGOODMON','CUTBACKSPEND',
  'MAJORPURCH','ECONBIN','NATECONIMPBIN'
)

# COMMHAP: force numeric and blank out codes above 5. Then drop respondents
# whose PARTY is 'REF' (rows with a missing PARTY are dropped by the same
# comparison).
dat <- dat %>%
  mutate(COMMHAP = as.numeric(COMMHAP),
         COMMHAP = ifelse(COMMHAP > 5,NA,COMMHAP)) %>%
  filter(PARTY != 'REF')

gc()

# Main loop: for every outcome and every time bucket,
#   1. keep complete cases and draw a 63% bootstrap sample (with replacement),
#   2. optionally dummy-code the factors, then centre/scale the predictors,
#   3. fit a ranger forest with permutation importance (per-variable VIMP),
#   4. for each covariate group, jointly shuffle that group's rows, refit, and
#      record the change in out-of-bag prediction error (group-level VIMP).
# One .RData file is written per outcome; outcomes whose file already exists
# are skipped so an interrupted job can be resumed.
zz <- Sys.time()

# save() does not create directories; make sure the target exists before
# spending time on model fits.
dir.create('./results/VIMP_ranger',recursive = TRUE,showWarnings = FALSE)

covGroups <- c('indivCovs','econCovs','healthCovs','crimeCovs','demogCovs')

for(Y in Ys) {

  outFile <- paste0('./results/VIMP_ranger/VIMP_2024_outcome-',Y,'_period-',per,'_bsInd-',bsInd,'_DumFacts-',dummy,'_temp-',temp,'.RData')

  if(file.exists(outFile)) {
    cat(outFile,'already exists\n')
    next
  }

  vimpRes <- NULL
  for(d in unique(dat$time)) {

    # for() strips the Date class, leaving days since the epoch.
    if(is.numeric(d)) {
      d <- as.Date(d,origin = '1970-01-01')
    }

    # all_of() makes it explicit that Y and Xs hold column NAMES.
    tmpDat <- dat %>%
      filter(time %in% d) %>%
      select(all_of(c(Y,Xs))) %>%
      drop_na()

    if(nrow(tmpDat) == 0) { next }

    tmpDat <- tmpDat %>%
      sample_n(size = round(.63*nrow(tmpDat)),replace = TRUE)

    if(dummy) {
      tmpDat <- fastDummies::dummy_cols(tmpDat,remove_selected_columns = TRUE,remove_first_dummy = TRUE)
    }

    pp <- preProcess(tmpDat %>% select(-all_of(Y)),method = c("center","scale"))
    ppX <- try(predict(pp,tmpDat %>% select(-all_of(Y))))

    if(inherits(ppX,'try-error')) {
      stop('Centering/scaling failed for outcome ',Y,', period ',as.character(d),
           call. = FALSE)
    }

    # Full permutation test
    m <- ranger(formula = Y ~ .,
                data = cbind(ppX,Y = tmpDat[[Y]]),
                importance = 'permutation',
                seed = 123,
                num.threads = ncores)

    if(dummy) {
      # Columns with fewer than 5 distinct values are treated as dummies and
      # get their column sum from the unscaled data as n; every other column
      # gets its number of distinct values.
      # A logical mask replaces x[-which(...)], which selects nothing when
      # the which() is empty.
      fcts2 <- sapply(ppX,function(x) length(unique(x)))
      isDummy <- fcts2 < 5
      fcts3 <- names(fcts2)[isDummy]
      counts <- tmpDat %>%
        summarise_at(vars(all_of(fcts3)),sum) %>%
        gather(vars,n)
      counts <- counts %>%
        bind_rows(data.frame(vars = names(fcts2)[!isDummy],
                             n = fcts2[!isDummy]))
    } else {
      counts <- data.frame(vars = names(m$variable.importance),
                           n = nrow(ppX))
    }

    vimpRes <- data.frame(vars = names(m$variable.importance),
                          vimp = m$variable.importance,
                          bsInd = bsInd,
                          outcome = Y,
                          predErr = m$prediction.error,
                          modelType = m$treetype,
                          period = as.character(d),
                          stringsAsFactors = FALSE) %>%
      as_tibble() %>%
      left_join(counts,by = 'vars') %>%
      bind_rows(vimpRes)

    prederr <- m$prediction.error

    # Testing groups of covariates
    permRes <- NULL
    for(j in covGroups) {
      covs <- get(j)

      # An empty group would turn into the pattern '', which matches every
      # column; there is nothing to permute, so skip it.
      if(length(covs) == 0) { next }

      permInds <- sample(seq_len(nrow(tmpDat)),size = nrow(tmpDat),replace = FALSE)

      # Columns belonging to this group: anything matching one of its names
      # (this also picks up dummy-expanded columns), minus anything matching
      # a name from another group.
      otherCovs <- setdiff(Xs,covs)
      permDat <- ppX %>%
        select(matches(paste(covs,collapse = '|')))
      if(length(otherCovs) > 0) {
        permDat <- permDat %>%
          select(-matches(paste(otherCovs,collapse = '|')))
      }
      # Shuffle the whole group with one row permutation.
      permDat <- permDat %>%
        slice(permInds)

      mTmp <- ranger(formula = Y ~ .,
                     data = cbind(permDat,ppX %>% select(-all_of(colnames(permDat))),
                                  Y = tmpDat[[Y]]),
                     importance = 'none',
                     seed = 123,
                     num.threads = ncores)

      # Group importance = increase in prediction error once the group is
      # scrambled.
      permRes <- data.frame(vars = j,
                            vimp = mTmp$prediction.error - prederr,
                            bsInd = bsInd,
                            outcome = Y,
                            predErr = mTmp$prediction.error,
                            modelType = mTmp$treetype,
                            period = as.character(d),
                            stringsAsFactors = FALSE) %>%
        as_tibble() %>%
        bind_rows(permRes)
    }

    vimpRes <- vimpRes %>%
      bind_rows(permRes)
  }

  cat(Y,'in',round(difftime(Sys.time(),zz,units = 'mins'),2),'minutes\n')
  zz <- Sys.time()

  # No bucket had complete cases: there is nothing to tabulate or save, and
  # count() on NULL would abort the remaining outcomes.
  if(is.null(vimpRes)) {
    cat(Y,'has no complete cases in any period; nothing saved\n')
    next
  }

  # Values are not auto-printed inside a loop, so print explicitly.
  print(vimpRes %>%
          count(outcome,period))

  save(vimpRes,
       file = outFile)
}
# EOF